{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.misc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Miscellaneous\n",
    "\n",
 This contains">
    "> Experimental model wrappers: input reshaping, residual connections and recursive prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class InputWrapper(Module):\n",
    "    def __init__(self, arch, c_in, c_out, seq_len, new_c_in=None, new_seq_len=None, **kwargs):\n",
    "\n",
    "        new_c_in = ifnone(new_c_in, c_in)\n",
    "        new_seq_len = ifnone(new_seq_len, seq_len)\n",
    "        self.new_shape = c_in != new_c_in or seq_len != new_seq_len\n",
    "        if self.new_shape:\n",
    "            layers = []\n",
    "            if c_in != new_c_in: \n",
    "                lin = nn.Linear(c_in, new_c_in)\n",
    "                nn.init.constant_(lin.weight, 0)\n",
    "                layers += [Transpose(1,2), lin, Transpose(1,2)]\n",
    "                lin2 = nn.Linear(seq_len, new_seq_len)\n",
    "                nn.init.constant_(lin2.weight, 0)\n",
    "                layers += [lin2]\n",
    "            self.new_shape_fn = nn.Sequential(*layers)\n",
    "        self.model = build_ts_model(arch, c_in=new_c_in, c_out=c_out, seq_len=new_seq_len, **kwargs)\n",
    "    def forward(self, x):\n",
    "        if self.new_shape: x = self.new_shape_fn(x) \n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.TST import *\n",
    "xb = torch.randn(16, 1, 1000)\n",
    "model = InputWrapper(TST, 1, 4, 1000, 10, 224)\n",
    "test_eq(model.to(xb.device)(xb).shape, (16,4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ResidualWrapper(Module):\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x[..., -1] + self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# RecursiveWrapper has not proved to be very useful so far."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class RecursiveWrapper(Module):\n",
    "    def __init__(self, model, n_steps, anchored=False):\n",
    "        self.model, self.n_steps, self.anchored = model, n_steps, anchored\n",
    "    def forward(self, x):\n",
    "        preds = []\n",
    "        for _ in range(self.n_steps): \n",
    "            pred = self.model(x)\n",
    "            preds.append(pred)\n",
    "            if x.ndim != pred.ndim: pred = pred[:, np.newaxis]\n",
    "            x = torch.cat((x if self.anchored else x[..., 1:], pred), -1)\n",
    "        return torch.cat(preds, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.TST import *\n",
    "xb = torch.randn(16, 1, 20)\n",
    "model = RecursiveWrapper(TST(1, 1, 20), 5)\n",
    "test_eq(model.to(xb.device)(xb).shape, (16, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from tsai.imports import create_scripts\n",
    "from tsai.export import get_nb_name\n",
    "nb_name = get_nb_name()\n",
    "create_scripts(nb_name);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
